import math
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from linearmodels.iv import IV2SLS


# Sentinel values that clean_numeric() replaces with NaN by default.
# NOTE(review): clean_numeric() applies the whole set to whatever column it is
# given, so genuine values equal to 7, 8, 9, 77, 88 or 99 are masked as well —
# confirm this is acceptable for wide-range columns before relying on it.
MISSING_CODES = {7, 8, 9, 77, 88, 99, 777, 888, 999, 7777, 8888, 9999}


def clean_numeric(series: pd.Series, missing_codes: "set[int] | None" = None) -> pd.Series:
    """Coerce *series* to numeric and replace missing-value codes with NaN.

    Args:
        series: Raw column; entries that cannot be parsed as numbers become NaN.
        missing_codes: Values to treat as missing. When omitted, the module-wide
            ``MISSING_CODES`` set is used. Pass a narrower collection for columns
            whose valid range overlaps the defaults (the default set contains
            7, 8, 9, 77, 88 and 99, which can be real values).

    Returns:
        A numeric Series with the same index as *series*.
    """
    # Resolve the default lazily so an explicit collection fully overrides it.
    codes = MISSING_CODES if missing_codes is None else missing_codes
    s = pd.to_numeric(series, errors="coerce")
    return s.mask(s.isin(list(codes)))


def recode_yes_no(series: pd.Series) -> pd.Series:
    """Turn a 1/2 coded item into a 1.0/0.0 indicator.

    The column is first passed through ``clean_numeric``; every value other
    than 1 or 2 (including anything blanked by the cleaning step) maps to NaN.
    """
    indicator_for_code = {1: 1.0, 2: 0.0}
    cleaned = clean_numeric(series)
    return cleaned.map(indicator_for_code)


def zscore(series: pd.Series) -> pd.Series:
    """Standardize *series* using its mean and population standard deviation.

    Values are coerced to numeric first (unparseable entries become NaN).
    When the standard deviation is NaN or not positive, the coerced values
    multiplied by zero are returned instead of dividing by it.
    """
    values = pd.to_numeric(series, errors="coerce")
    spread = values.std(ddof=0)
    if pd.notna(spread) and spread > 0:
        centered = values - values.mean()
        return centered / spread
    # No usable spread: return zeros (NaN entries stay NaN).
    return values * 0


# Per-column missing-value codes used by build_dataset().
# The blanket MISSING_CODES set contains 7/8/9 and 77/88/99, which are real values
# for wide-range columns (ages 77/88/99, 7-9 years of schooling, income deciles
# 7-9), so each column only masks codes that lie outside its own valid range.
# NOTE(review): the ranges follow the ESS codebook conventions (including
# 6666 = "not applicable" for netustm) — confirm against the codebook of the
# rounds actually loaded.
_COLUMN_MISSING_CODES = {
    "netusoft": {7, 8, 9},
    "netustm": {6666, 7777, 8888, 9999},
    "agea": {777, 888, 999},
    "gndr": {7, 8, 9},
    "eduyrs": {77, 88, 99},
    "hinctnta": {77, 88, 99},
}


def _mask_codes(series: pd.Series, codes: set) -> pd.Series:
    """Coerce *series* to numeric and replace the given *codes* with NaN."""
    values = pd.to_numeric(series, errors="coerce")
    return values.mask(values.isin(list(codes)))


def build_dataset(root: Path) -> pd.DataFrame:
    """Assemble the analysis table: ESS rows plus region-year instruments.

    Reads ``outputs/ess_r8_r11_min.parquet`` and the Eurostat region-year CSV
    under ``data_external/`` below *root*, left-joins the regional table onto
    the ESS rows by region and survey year, and derives the internet-use
    variables, standardized instruments, participation outcomes, controls,
    the education split and the country-by-year key ``ct``.

    Args:
        root: Project root containing ``outputs/`` and ``data_external/``.

    Returns:
        The merged frame restricted to rows where at least one standardized
        instrument is available.
    """
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    reg = pd.read_csv(root / "data_external" / "broadband_region_year_eurostat_isoc_r_broad_h.csv")

    # Keep regunit 1/2 rows and drop the "99999" placeholder region.
    ess = ess[ess["regunit"].isin([1, 2])].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()
    ess = ess[ess["region"] != "99999"].copy()

    reg["region"] = reg["region"].astype(str).str.upper().str.strip()
    reg["year"] = pd.to_numeric(reg["year"], errors="coerce").astype("Int64")
    # reg.get() tolerates an absent indicator column; it then becomes all-NaN.
    reg["broadband_pc_hh"] = pd.to_numeric(reg.get("broadband_pc_hh"), errors="coerce")
    reg["internet_access_pc_hh"] = pd.to_numeric(reg.get("internet_access_pc_hh"), errors="coerce")

    # Merge on region × survey_year
    df = ess.merge(reg, left_on=["region", "survey_year"], right_on=["region", "year"], how="left")

    # Endogenous internet use (use frequency; time is available but noisier)
    for col in ("netusoft", "netustm"):
        df[col] = _mask_codes(df[col], _COLUMN_MISSING_CODES[col])
    # Vectorized log(1 + x); NaN propagates unchanged.
    df["log_netustm"] = np.log1p(df["netustm"].clip(lower=0))

    # Instruments (region-year)
    df["z_broadband_pc_hh"] = zscore(df["broadband_pc_hh"])
    df["z_internet_access_pc_hh"] = zscore(df["internet_access_pc_hh"])

    # Outcomes
    for item in ("vote", "sgnptit", "pbldmn", "bctprd", "contplt", "badge"):
        df[f"y_{item}"] = recode_yes_no(df[item])
    df["y_pstplonl"] = recode_yes_no(df["pstplonl"]) if "pstplonl" in df.columns else np.nan

    # Participation index (mean of available non-institutional items)
    part_cols = ["y_sgnptit", "y_pbldmn", "y_bctprd", "y_contplt", "y_badge", "y_pstplonl"]
    df["y_part_index"] = df[part_cols].mean(axis=1, skipna=True)
    df.loc[df[part_cols].isna().all(axis=1), "y_part_index"] = np.nan

    # Controls: mask only the codes that cannot be real values of each column.
    for col in ("agea", "gndr", "eduyrs", "hinctnta"):
        df[col] = _mask_codes(df[col], _COLUMN_MISSING_CODES[col])

    # Education group (simple, interpretable): 1.0 if > 12 years, 0.0 if <= 12, NaN if unknown.
    df["edu_high"] = (df["eduyrs"] > 12).astype(float).where(df["eduyrs"].notna())

    # FE key
    df["ct"] = df["cntry"].astype(str) + "_" + df["survey_year"].astype("Int64").astype(str)

    # Keep rows with instruments and key vars
    df = df[df[["z_broadband_pc_hh", "z_internet_access_pc_hh"]].notna().any(axis=1)].copy()
    return df


def make_fe_dummies(df: pd.DataFrame, fe_cols: list[str]) -> pd.DataFrame:
    """Expand the fixed-effect columns into dummy indicators.

    Each column in *fe_cols* is cast to string and replaced by one indicator
    column per level, with the first level omitted. The input frame is not
    modified; all other columns are carried over unchanged.
    """
    as_text = df.astype({col: str for col in fe_cols})
    return pd.get_dummies(as_text, columns=fe_cols, drop_first=True)


def run_itt(df: pd.DataFrame, y: str, out_lines: list[str]) -> None:
    """Fit the reduced-form (ITT) regression for outcome *y* and log the result.

    The outcome is regressed by OLS on both standardized instruments, the four
    controls and ``C(ct)`` fixed effects, with standard errors clustered on
    ``region``. Rows with any missing model variable are dropped first. Summary
    lines are appended to *out_lines*; nothing is returned.
    """
    instruments = ("z_broadband_pc_hh", "z_internet_access_pc_hh")
    needed = [y, *instruments, "agea", "gndr", "eduyrs", "hinctnta", "ct", "region"]
    sample = df[needed].dropna()
    if sample.empty:
        out_lines.append(f"[ITT SKIP] {y}: no data")
        return

    formula = f"{y} ~ z_broadband_pc_hh + z_internet_access_pc_hh + agea + gndr + eduyrs + hinctnta + C(ct)"
    model = smf.ols(formula, data=sample)
    fit = model.fit(cov_type="cluster", cov_kwds={"groups": sample["region"]})

    n_regions = sample["region"].nunique()
    n_ct = sample["ct"].nunique()
    out_lines.append(f"[ITT] {y}: n={len(sample):,} regions={n_regions} ct={n_ct}")

    # One line per instrument; NaN is reported if a term is absent from the fit.
    nan = float("nan")
    for name in instruments:
        coef = fit.params.get(name, nan)
        se = fit.bse.get(name, nan)
        out_lines.append(f"  coef({name})={coef:.4f} se={se:.4f}")


def run_iv(df: pd.DataFrame, y: str, out_lines: list[str], add_region_fe: bool) -> None:
    """Fit a 2SLS model of *y* on ``netusoft`` and append a summary line.

    ``netusoft`` is instrumented with the two standardized region-year
    indicators; controls are ``agea``, ``gndr``, ``eduyrs`` and ``hinctnta``;
    standard errors are clustered on ``region``.

    Args:
        df: Analysis frame holding the outcome, endogenous variable,
            instruments, controls and the ``ct``/``region``/``survey_year``/
            ``cntry`` keys.
        y: Outcome column name.
        out_lines: Report buffer; exactly one line is appended per call
            (result, skip notice or error message).
        add_region_fe: If True, use region, country and survey-year fixed
            effects on regions observed in at least two survey years;
            otherwise use country-by-year (``ct``) fixed effects.
    """
    # IV: y ~ controls + FE + [netusoft ~ z_broadband + z_internet_access]; cluster by region
    cols = [
        y,
        "netusoft",
        "z_broadband_pc_hh",
        "z_internet_access_pc_hh",
        "agea",
        "gndr",
        "eduyrs",
        "hinctnta",
        "ct",
        "region",
        "survey_year",
        "cntry",
    ]
    d = df[cols].dropna().copy()
    if d.empty:
        out_lines.append(f"[IV SKIP] {y}: no data")
        return

    if add_region_fe:
        # Alternative robustness spec to avoid rank issues with sparse region-year coverage:
        # region FE + country FE + survey-year FE (instead of country×year FE).
        region_year_n = d.groupby("region")["survey_year"].nunique()
        keep_region = region_year_n[region_year_n >= 2].index
        d = d[d["region"].isin(keep_region)].copy()
        if d.empty:
            out_lines.append(f"[IV regionFE+countryFE+yearFE] {y}: SKIP (insufficient multi-year regions)")
            return
        # When every region belongs to exactly one country, the country dummies are
        # exact linear combinations of the region dummies. Region FE then already
        # absorb the country effects, and adding country dummies would only make
        # the design rank-deficient.
        regions_nested = d.groupby("region")["cntry"].nunique().max() <= 1
        fe_cols = ["region", "survey_year"] if regions_nested else ["region", "cntry", "survey_year"]
        spec = "regionFE+countryFE+yearFE"
    else:
        fe_cols = ["ct"]
        spec = "ctFE"

    # Capture the cluster key from the estimation sample itself, before the region
    # column can be consumed by the dummy expansion. This avoids a label-based
    # lookup into the caller's frame, which would depend on its index being unique.
    clusters = d["region"].copy()

    d = make_fe_dummies(d, fe_cols)

    yv = d[y]
    endog = d["netusoft"]
    instr = d[["z_broadband_pc_hh", "z_internet_access_pc_hh"]]
    # "ct" must be dropped too: in the region-FE spec it is not dummy-expanded and
    # would otherwise remain in exog as a raw string column, reintroducing the
    # country×year effects this spec is meant to replace.
    exog = d.drop(
        columns=[y, "netusoft", "z_broadband_pc_hh", "z_internet_access_pc_hh", "region", "survey_year", "cntry", "ct"],
        errors="ignore",
    )
    exog = exog.loc[:, ~exog.columns.duplicated()]
    # Drop any constant columns (can appear after filtering)
    nunique = exog.nunique(dropna=False)
    exog = exog.loc[:, nunique > 1].astype(float)
    # The dummies omit one reference level per factor, so the model needs an explicit
    # intercept; IV2SLS does not add one. Added after the constant-column filter
    # so it is not removed again.
    exog.insert(0, "const", 1.0)

    try:
        mod = IV2SLS(yv, exog=exog, endog=endog, instruments=instr)
        res = mod.fit(cov_type="clustered", clusters=clusters)
        fs = res.first_stage.diagnostics
        fs_row = fs.loc["netusoft"] if "netusoft" in fs.index else fs.iloc[0]
        fs_f = float(fs_row.get("f.stat", float("nan")))
        fs_p = float(fs_row.get("f.pval", float("nan")))
        out_lines.append(
            f"[IV {spec}] {y}: n={len(d):,} regions={clusters.nunique()} coef(netusoft)={res.params.get('netusoft', float('nan')):.4f} se={res.std_errors.get('netusoft', float('nan')):.4f} | first_stage_F={fs_f:.2f} p={fs_p:.4g}"
        )
    except Exception as e:
        # Best-effort reporting: one failing spec (e.g. a rank-deficient design)
        # must not abort the remaining outcomes.
        out_lines.append(f"[IV {spec}] {y}: ERROR {type(e).__name__}: {e}")


def run_iv_by_group(df: pd.DataFrame, y: str, out_lines: list[str]) -> None:
    """Run the country-by-year FE IV model separately for each education group.

    Rows are split on ``edu_high`` (0.0 and 1.0). For every non-empty group the
    last line produced by ``run_iv`` is appended to *out_lines* behind a group
    prefix; empty groups produce no output.
    """
    level_by_label = {"edu_low(<=12)": 0.0, "edu_high(>12)": 1.0}
    for label, level in level_by_label.items():
        subset = df.loc[df["edu_high"] == level].copy()
        if subset.empty:
            continue
        # Collect run_iv's output separately so only its final line is reported.
        captured: list[str] = []
        run_iv(subset, y, captured, add_region_fe=False)
        summary = captured[-1] if captured else "SKIP"
        out_lines.append(f"[IV by group] {y} {label}: " + summary)


def main() -> None:
    """Build the dataset, run every specification and write the text report.

    For each outcome the ITT model, both IV specifications and the
    education-group IV models are run in that order. The collected lines are
    written to ``outputs/iv_region_year_inequality.txt`` under the project
    root (one directory above this file's folder), and the path is printed.
    """
    root = Path(__file__).resolve().parents[1]
    report_path = root / "outputs" / "iv_region_year_inequality.txt"

    df = build_dataset(root)
    years = sorted(map(int, df["survey_year"].dropna().unique()))
    countries = sorted(set(df["cntry"].dropna().astype(str)))

    lines: list[str] = [
        "Region-year IV/ITT (Eurostat isoc_r_broad_h -> internet use -> participation)",
        f"Years in sample: {years}",
        f"Countries in sample: {len(countries)}",
        "",
    ]

    # Outcome column -> human-readable label, in reporting order.
    label_by_outcome = {
        "y_part_index": "Participation index",
        "y_vote": "Voted",
        "y_sgnptit": "Signed petition",
        "y_pbldmn": "Demonstration",
        "y_bctprd": "Boycott",
        "y_contplt": "Contacted politician",
    }

    for outcome, label in label_by_outcome.items():
        lines.append(f"== {label} ({outcome}) ==")
        run_itt(df, outcome, lines)
        for with_region_fe in (False, True):
            run_iv(df, outcome, lines, add_region_fe=with_region_fe)
        run_iv_by_group(df, outcome, lines)
        lines.append("")

    report_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    print(f"Wrote: {report_path}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
